import torch
import torch.nn as nn
from torch.nn import utils
from torch.nn import init
import torch.nn.functional as F


class LeNetDecoder(nn.Module):
  """Map a feature vector to ``num_outputs`` values with one linear layer.

  Args:
    num_outputs: size of the produced output (last dimension).
    num_features: size of the incoming feature vector (default 64).
  """

  def __init__(self, num_outputs, num_features=64):
    super().__init__()
    # state_dict keys are derived from this attribute name, so it must stay
    # ``linear2`` for previously saved weights to keep loading.
    self.linear2 = nn.Linear(num_features, num_outputs)

  def forward(self, x):
    """Return ``self.linear2(x)``."""
    return self.linear2(x)


class SNPDMNIST(nn.Module):
  """Base class holding an encoder plus shared configuration.

  Concrete models are expected to subclass this and override ``forward``;
  calling ``forward`` on the base class raises ``NotImplementedError``.

  Args:
    encoder: stored as ``self.encoder`` for use by subclasses.
    num_features: stored as ``self.num_features``.
    num_classes: stored as ``self.num_classes`` (default 0).
  """

  def __init__(self, encoder, num_features, num_classes=0):
    super().__init__()
    self.encoder = encoder
    self.num_features = num_features
    self.num_classes = num_classes

  def forward(self, *args):
    """Abstract; subclasses must provide the forward computation."""
    raise NotImplementedError


class SNPDFC3(SNPDMNIST):
  """Spectrally-normalized 3-layer fully connected head with projection.

  The input is passed through ``encoder``. It then goes through two
  spectrally-normalized ``num_features -> num_features`` linear layers, each
  followed by an in-place LeakyReLU with slope 0.2. A final
  spectrally-normalized linear layer ``l7`` produces one score per sample.
  When ``num_classes > 0``, a spectrally-normalized class embedding ``l_y``
  is also built. Labels passed to :meth:`forward` then add
  ``sum(l_y(y) * h, dim=1)`` to the score.

  Args:
    encoder: module applied to the input first; its output is assumed to
      have ``num_features`` features in the last dimension (TODO confirm
      against callers).
    num_features: width of the hidden layers and of the class embedding.
    num_classes: number of label classes; ``0`` (default) builds an
      unconditional model without ``l_y``.
    activation: accepted for backward compatibility but NOT used; the hidden
      layers always use ``nn.LeakyReLU(negative_slope=0.2)``.
  """

  def __init__(self, encoder, num_features, num_classes=0, activation=F.relu):
    super(SNPDFC3, self).__init__(
      encoder=encoder, num_features=num_features,
      num_classes=num_classes)

    self.linear1 = utils.spectral_norm(nn.Linear(num_features, num_features))
    self.relu1 = nn.LeakyReLU(inplace=True, negative_slope=0.2)
    self.linear2 = utils.spectral_norm(nn.Linear(num_features, num_features))
    self.relu2 = nn.LeakyReLU(inplace=True, negative_slope=0.2)

    self.l7 = utils.spectral_norm(nn.Linear(num_features, 1))
    if num_classes > 0:
      self.l_y = utils.spectral_norm(
        nn.Embedding(num_classes, num_features))

    self._initialize()

  def _initialize(self):
    """Xavier-uniform initialise ``l7`` and, if present, ``l_y``."""
    # NOTE(review): with torch.nn.utils.spectral_norm the trainable tensor is
    # registered as ``weight_orig``. Writing into ``weight.data`` here relies
    # on ``weight`` sharing storage with it before the first forward pass.
    # Confirm this holds for the torch version in use.
    init.xavier_uniform_(self.l7.weight.data)
    optional_l_y = getattr(self, 'l_y', None)
    if optional_l_y is not None:
      init.xavier_uniform_(optional_l_y.weight.data)

  def forward(self, x, y=None):
    """Score ``x``, optionally conditioned on class labels ``y``.

    Args:
      x: input passed to ``self.encoder``.
      y: optional class indices looked up in ``l_y``; only valid when the
        model was built with ``num_classes > 0``.

    Returns:
      Output of ``l7`` (trailing dimension of size 1), plus the projection
      term when ``y`` is given.

    Raises:
      ValueError: if ``y`` is given but the model has no class embedding.
    """
    # Fail fast with a clear message. Without this check, the missing
    # ``l_y`` surfaces only after the encoder has run, as an opaque
    # AttributeError from nn.Module.__getattr__.
    if y is not None and getattr(self, 'l_y', None) is None:
      raise ValueError(
        'SNPDFC3.forward received labels y, but the model was built with '
        'num_classes=%r; construct it with num_classes > 0 to enable '
        'class conditioning.' % (self.num_classes,))

    h = self.encoder(x)
    h = self.relu1(self.linear1(h))
    h = self.relu2(self.linear2(h))

    output = self.l7(h)
    if y is not None:
      # Projection term: elementwise product of class embedding and hidden
      # features, summed over dim 1 (kept as a size-1 dim to match ``l7``).
      output = output + torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
    return output
